# -*- coding:utf-8 -*-
from __future__ import print_function
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_for_infer
import paddle.fluid.profiler as profiler
from crnn_ctc_model import ctc_infer
from attention_model import attention_infer
import Tkinter as tk
import numpy as np
import data_reader
import argparse
import functools
import os
import time
import cv2

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable


add_arg('model',              str,   "crnn_ctc",   "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('model_path',         str,  './models/ctc_model_50000',   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('dict',               str,  None,   "The dictionary. The result of inference will be index sequence if the dictionary was None.")
add_arg('use_gpu',            bool,  False,      "Whether use GPU to infer.")
add_arg('iterations',         int,  0,      "The number of iterations. Zero or less means whole test set. More than 0 means the test set might be looped until # of iterations is reached.")
add_arg('profile',            bool, False,  "Whether to use profiling.")
add_arg('skip_batch_num',     int,  0,      "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('batch_size',         int,  1,      "The minibatch size.")

# yapf: enable
args = parser.parse_args()
print_arguments(args)

# Pick the network builder and the matching feeder for the requested model;
# any value other than 'crnn_ctc' selects the attention model.
if args.model == "crnn_ctc":
    infer = ctc_infer
    get_feeder_data = get_ctc_feeder_data
else:
    infer = attention_infer
    get_feeder_data = get_attention_feeder_for_infer
# End/start-of-sequence token ids (only meaningful for prune(); not referenced
# elsewhere in the visible file).
eos = 1
sos = 0
num_classes = data_reader.num_classes()
data_shape = data_reader.data_shape()

# Define the inference network; cuDNN kernels are only requested on GPU.
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
ids = infer(images, num_classes, use_cudnn=bool(args.use_gpu))

# Data reader. With a positive --iterations the reader cycles over the input
# so the inference loop stops on the iteration count instead of on exhaustion.
infer_reader = data_reader.inference(
    batch_size=args.batch_size,
    infer_images_dir=args.input_images_dir,
    infer_list_file=args.input_images_list,
    cycle=args.iterations > 0,
    model=args.model)

# Prepare the execution environment.
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Optional index -> token dictionary: line i of the file is the token for id i.
dict_map = None
if args.dict is not None:
    if os.path.isfile(args.dict):
        with open(args.dict) as dict_file:
            dict_map = {i: word.strip() for i, word in enumerate(dict_file)}
        print("Loaded dict from %s" % args.dict)
    else:
        # Previously a bad --dict path was ignored without any feedback.
        print("Dict file %s not found; ignoring --dict." % args.dict)

# Load trained parameters. --model_path is either a directory (parameters in
# separate files) or a single combined parameter file inside a directory.
model_dir = args.model_path
model_file_name = None
if not os.path.isdir(args.model_path):
    model_dir = os.path.dirname(args.model_path)
    model_file_name = os.path.basename(args.model_path)
fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
print("Init model from: %s." % args.model_path)

def img_display(path):
    """Open ``path`` in the desktop's default viewer via ``xdg-open``.

    The path is passed as a separate argv element rather than interpolated
    into a shell command, so file names containing spaces or shell
    metacharacters (the path comes straight from the GUI entry) are safe.
    Best-effort: if ``xdg-open`` cannot be launched, a message is printed
    instead of raising.
    """
    import subprocess
    try:
        subprocess.call(['xdg-open', path])
    except OSError as err:
        print("Could not open %s with xdg-open: %s" % (path, err))

def preprocess(path, cascade_path='./cascade.xml', output_path='new.png'):
    """Detect a licence plate in an image, save the crop and display it.

    The image is rescaled to a height of 1000 px (aspect ratio kept) and
    searched with the Haar cascade at ``cascade_path``. The last detected
    region is cropped, resized to 185x48 and written to ``output_path``.
    NOTE(review): the inference reader presumably reads this file via
    --input_images_dir/--input_images_list; confirm at launch. Two OpenCV
    windows then show the crop and the annotated image until a key is pressed.

    Args:
        path: path of the source image.
        cascade_path: Haar cascade XML used for plate detection.
        output_path: where the resized plate crop is written.

    Returns:
        ``output_path``.

    Raises:
        IOError: if the image cannot be read.
        ValueError: if no plate is detected.
    """
    watch_cascade = cv2.CascadeClassifier(cascade_path)
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError("Cannot read image: %s" % path)

    # Normalise the height to 1000 px, keeping the aspect ratio.
    resize_h = 1000
    scale = image.shape[1] / float(image.shape[0])
    image = cv2.resize(image, (int(scale * resize_h), resize_h))

    # cv2.imread yields BGR channel order, so convert with BGR2GRAY.
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    watches = watch_cascade.detectMultiScale(
        image_gray, 1.1, 3, minSize=(36, 9), maxSize=(36 * 40, 9 * 40))
    if len(watches) == 0:
        # Without this check the crop below would reference undefined names.
        raise ValueError("No licence plate detected in %s" % path)

    # The last detection is used. Crop (and copy, since a slice is a view)
    # *before* drawing the boxes so the green outline is not fed to the OCR.
    x, y, w, h = watches[-1]
    crop_img = image[y:y + h, x:x + w].copy()
    for (bx, by, bw, bh) in watches:
        cv2.rectangle(image, (bx, by), (bx + bw, by + bh), (0, 255, 0), 2)

    cv2.namedWindow('image', cv2.WINDOW_GUI_NORMAL)
    cv2.imshow('result', crop_img)
    cv2.moveWindow('result', 0, 0)
    # Fixed 185x48 size, presumably the recogniser's expected input size.
    crop_img = cv2.resize(crop_img, (185, 48))
    cv2.imwrite(output_path, crop_img)
    cv2.imshow("image", image)
    cv2.moveWindow('image', 0, 0)
    # Block until a key is pressed, then close both windows.
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return output_path

# Index -> character table for Chinese licence plates: 31 province
# abbreviations, the digits 0-9, then the 24 Latin letters used on plates
# (I and O are excluded). Index 2 is "津" (Tianjin); the previous inline
# table listed "沪" at both 1 and 2, so Tianjin plates could never be decoded.
# NOTE(review): this order must match the label ids used in training.
PLATE_CHAR_MAP = dict(enumerate(
    ["京", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑",
     "苏", "浙", "皖", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤",
     "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新"]
    + list("0123456789")
    + list("ABCDEFGHJKLMNPQRSTUVWXYZ")))


def inference(args):
    """Run the recognition network over ``infer_reader`` and decode the result.

    Each batch from the module-level ``infer_reader`` is fed through ``exe``.
    The predicted indexes of every processed batch are mapped to characters
    and concatenated. The dictionary loaded from --dict is used when present;
    otherwise the built-in ``PLATE_CHAR_MAP`` is used.

    Args:
        args: parsed command-line arguments; reads ``iterations`` and
            ``skip_batch_num``.

    Returns:
        The decoded string (for a single plate image, the plate text).

    Raises:
        KeyError: if the network predicts an index missing from the map.
    """
    # Previously a local table shadowed the --dict global, so --dict was
    # silently ignored.
    char_map = dict_map if dict_map is not None else PLATE_CHAR_MAP
    chars = []
    iters = 0
    for data in infer_reader():
        # Check the stop condition first so no feed is built for a batch
        # that will not be run.
        if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
            break
        feed_dict = get_feeder_data(data, place, need_label=False)
        if iters < args.skip_batch_num:
            print("Warm-up iteration")
        if iters == args.skip_batch_num:
            profiler.reset_profiler()

        result = exe.run(fluid.default_main_program(),
                         feed=feed_dict,
                         fetch_list=[ids],
                         return_numpy=False)
        indexes = np.array(result[0]).flatten()
        chars.extend(char_map[index] for index in indexes)
        iters += 1

    # Join once instead of growing a string with += per character.
    return ''.join(chars)

def prune(words, sos, eos):
    """Trim a predicted token array to the span between the sequence markers.

    Everything up to and including the first ``sos`` is dropped, as is
    everything from the first ``eos`` onward. A missing ``sos`` keeps the
    span starting at 0; a missing ``eos`` keeps it running to the end.
    """
    def first_position(token):
        return np.where(words == token)[0][0]

    begin = first_position(sos) + 1 if sos in words else 0
    end = first_position(eos) if eos in words else len(words)
    return words[begin:end]

def ocr_func():
    """Run ``inference(args)``, optionally wrapped in a Paddle profiler.

    With --profile, GPU runs use the CUDA profiler (CSV output to
    ``cuda_profiler.txt``); CPU runs use the CPU profiler sorted by total time.
    Returns whatever ``inference`` returns.
    """
    if not args.profile:
        return inference(args)
    if args.use_gpu:
        profile_ctx = profiler.cuda_profiler("cuda_profiler.txt", 'csv')
    else:
        profile_ctx = profiler.profiler("CPU", sorted_key='total')
    with profile_ctx:
        return inference(args)

# GUI layout: a fixed-size Tk window with an entry for the source image path
# and a one-line text box that shows the recognition result.
window = tk.Tk()
window.title('车牌识别')  # title text means "Licence plate recognition"
window.geometry('280x120')

lb1 = tk.Label(window, text='原始图片路径')  # label text: "Original image path"
lb1.place(x=15, y=15)

# Entry where the user types the path of the image to recognise; read by start_rec.
e1 = tk.Entry(window, width=35, show=None)
e1.place(x=15, y=35)

lb2 = tk.Label(window, text='识别结果')  # label text: "Recognition result"
lb2.place(x=15, y=70)

# Single-line text box that start_rec overwrites with the decoded plate string.
t1 = tk.Text(window, width=35, height=1)
t1.place(x=15, y=90)

def start_rec():
    """Button callback: detect and crop the plate, recognise it, and show it.

    Reads the image path from the entry, runs plate detection/cropping and
    then recognition, replaces the result box contents with the recognised
    text, and finally opens the original image in the system viewer.
    """
    source_path = e1.get()
    preprocess(source_path)
    plate_text = ocr_func()
    t1.delete('1.0', tk.END)
    t1.insert(tk.INSERT, plate_text)
    img_display(source_path)

# "开始识别" means "Start recognition"; clicking it runs start_rec.
b1 = tk.Button(window, text='开始识别', width=10,
        height=1, command=start_rec)
b1.place(x=167, y=10)

# Hand control to the Tk event loop; this call blocks until the window closes.
window.mainloop()
